from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
from io import BytesIO
import json
import logging
import base64
import threading
import random
import numpy as np
from typing import Callable, List, Tuple, Union
from PIL import Image
import torch.utils.data as data

from image_synthesis.data.utils.tsv_file import TSVFile, CompositeTSVFile
from image_synthesis.utils.misc import instantiate_from_config
from image_synthesis.data.utils.util import generate_stroke_mask

def pre_fetch(tsv_filename: str):
    """Open ``tsv_filename`` and close it again, logging before and after.

    No bytes are read. The file is only opened (in text mode) and closed.
    This is used as the target of a background daemon thread.
    NOTE(review): it presumably relies on the underlying filesystem fetching
    the file on open (e.g. a blob-backed mount); confirm before relying on it
    to warm a cache.

    Raises:
        OSError: if the file cannot be opened (e.g. FileNotFoundError).
    """
    logging.info('Pre-loading %s ...', tsv_filename)
    handle = open(tsv_filename, 'r')
    try:
        # The "ended" message is emitted while the file is still open.
        logging.info('Pre-loading %s ended.', tsv_filename)
    finally:
        handle.close()


class TSVDataset(data.Dataset):
    """Image-classification dataset read from one or more TSV files.

    Each row is decoded as ``(key, label, base64-encoded image)``. The label is
    either a plain integer string or, when a ``map_file`` is given, a JSON list
    whose first element's ``'class'`` name is looked up in that map.

    The class is currently disabled: the constructor raises
    ``NotImplementedError`` before doing any work because it has not been tested.
    """

    def __init__(self,
                 tsv_file: Union[str, List[str]],
                 transform: Callable = None,
                 map_file: str = None):
        # TODO: remove this guard once the class has been tested.
        raise NotImplementedError('This class is not tested, please test it first!')
        self.transform = transform
        self._chunk_sizes = None
        self.label2idx = self._load_map(map_file)
        # Restrict the TSV readers to the mapped class names, if any.
        if self.label2idx:
            self.class_selector = list(self.label2idx)
        else:
            self.class_selector = None

        selector = self.class_selector
        if isinstance(tsv_file, list):
            self.tsv_file = CompositeTSVFile(tsv_file, class_selector=selector)
            self._chunk_sizes = self.tsv_file.get_chunk_size()
        elif isinstance(tsv_file, str):
            if os.path.splitext(tsv_file)[1] == '.tsv':
                self.tsv_file = TSVFile(tsv_file, class_selector=selector)
            else:
                # Any non-.tsv name is treated as a list of TSV files.
                self.tsv_file = CompositeTSVFile(tsv_file, class_selector=selector)
                self._chunk_sizes = self.tsv_file.get_chunk_size()
        else:
            raise ValueError("Invalid input! Please check the tsv filenames")

        logging.debug('=> {}\titems: {}'.format(tsv_file, len(self.tsv_file)))

    def num_classes(self):
        """Number of selected class names (requires a ``map_file``)."""
        return len(self.class_selector)

    def get_chunk_sizes(self):
        """Chunk sizes of a composite TSV source, or ``None`` for a single file."""
        return self._chunk_sizes

    def get_class_boundaries(self):
        # Samples are stored class by class; the underlying reader reports the
        # lower and upper row bound of each class.
        return self.tsv_file.get_class_boundaries()

    def get_filenames(self):
        """Return the key of every row, in row order."""
        return list(map(self.tsv_file.get_key, range(self.tsv_file.num_rows())))

    def _load_map(self, map_file: str):
        """Read a ``label<TAB>index`` file into a dict; falsy ``map_file`` gives ``None``."""
        if not map_file:
            return None
        with open(map_file) as handle:
            rows = (line.strip().split('\t') for line in handle)
            return {fields[0]: int(fields[1]) for fields in rows}

    def __getitem__(self, index: Union[int, Tuple[int, int]]):
        """Return ``(image, label)``.

        A tuple index is ``(row, file_index)``. A non-negative ``file_index``
        also starts a daemon thread running ``pre_fetch`` on that TSV file.
        """
        if not isinstance(index, tuple):
            items = self.tsv_file[index]
        else:
            items = self.tsv_file[index[0]]
            if index[1] >= 0:
                prefetcher = threading.Thread(
                    target=pre_fetch,
                    args=(self.tsv_file.file_list[index[1]],),
                    daemon=True,
                )
                prefetcher.start()

        _key, target, img = self._decode_data(items)
        if self.transform:
            img = self.transform(img)
        return img, target

    def _decode_data(self, items: Tuple[str, str, str]):
        """Decode one row into ``(key, label, RGB PIL image)``."""
        key = items[0]
        label = self._get_label(items[1])
        raw_bytes = base64.b64decode(items[2])
        return key, label, Image.open(BytesIO(raw_bytes)).convert('RGB')

    def _get_label(self, item: str):
        """Parse the label column, using ``label2idx`` when one was loaded."""
        if self.label2idx:
            return self.label2idx[json.loads(item)[0]['class']]
        return int(item)

    def __len__(self):
        return len(self.tsv_file)


class TSVImageTextDataset(data.Dataset):
    """
        This class is intended for encapsulating Image/Text pair data for contrastive learning described in
        the following paper,
        "Learning Transferable Visual Models From Natural Language Supervision" (a.k.a CLIP)

        Images and texts are read from two parallel TSV sources. Rows are aligned by index,
        and the first column of both sides is a shared key. ``__getitem__`` returns a dict
        holding:
        - the preprocessed image as float32, with its last axis moved to the front;
        - its lower-cased caption;
        - optionally a random stroke mask, a 256-pixel copy of the image, and a
          low-resolution colour-quantized "inferior" image, depending on the
          constructor flags.
    """
    def __init__(self,
                 name,
                 image_tsv_file: Union[str, List[str]],
                 text_tsv_file: Union[str, List[str]],
                 data_root='',
                 num_captions=1,
                 text_format='txt',
                 filter_texts=None,
                 # NOTE: this dict default is shared by every instance; it must not be mutated.
                 im_preprocessor_config={
                     'target': 'image_synthesis.data.utils.image_preprocessor.SimplePreprocessor',
                     'params':{
                        'size': 256,
                        'random_crop': False,
                        'horizon_flip': False,
                        }
                 },
                 text_preprocessor_config=None,
                 text_tokenizer_config=None,
                 load_random_mask=False,
                 indices_list_file=None, # if given, load data according to the given indices
                 inferior_size=None, #  (h, w), used for train inpainting model
                 inferior_random_degree=2,
                 mask_low_to_high=-1.0,
                 pixel_kmens_center_path='data/kmeans_centers.npy',
                 return_image256=False
        ):
        """
        Args:
            name: sub-directory of ``data_root`` that holds the TSV files. Every file name
                is resolved as ``os.path.join(data_root, name, file)``.
            image_tsv_file, text_tsv_file: a file name or a list of file names.
                - A single ``.tsv`` file is opened with ``TSVFile``.
                - A single ``.txt`` file, or a list of files, is opened with
                  ``CompositeTSVFile``.
                Both sides must have the same number of rows.
            data_root: root directory; ``''`` means ``'data'``.
            num_captions: applies when ``text_format == 'json'`` and the record holds a
                caption list. 1 picks one caption at random; larger values keep the first
                ``num_captions`` captions as a list.
            text_format: ``'json'`` reads captions from the ``captions`` field of a JSON
                record; any other value uses the raw second column as the caption.
            filter_texts: optional list; each entry is a string or a list of strings. A
                sample is kept if, for at least one entry, every string is a substring
                of its caption.
            im_preprocessor_config: config passed to ``instantiate_from_config``. The
                result is called as ``preprocessor(image=img)['image']``.
            text_preprocessor_config: optional config for a callable applied to the
                lower-cased caption.
            text_tokenizer_config: optional config for a tokenizer. Samples whose tokens
                fail ``check_length`` are dropped.
            load_random_mask: if True, every sample also carries a random stroke mask.
            indices_list_file: optional text file with one row index per line. Only
                those rows are used; blank lines are ignored.
            inferior_size: optional ``(h, w)``. When set, each sample also carries a
                degraded low-resolution copy of the image (see ``get_inferior``).
            inferior_random_degree: K in the top-K nearest-colour quantization.
            mask_low_to_high: probability of replacing the random mask with a copy that
                was down-sampled to ``inferior_size`` and up-sampled again. Only used when
                both ``inferior_size`` and ``load_random_mask`` are set; the default
                ``-1.0`` disables it.
            pixel_kmens_center_path: ``.npy`` file with the colour cluster centres used by
                ``get_inferior``.
            return_image256: if True, each sample also carries ``image256``, produced by a
                fixed ``DalleTransformerPreprocessor`` with ``size=256``.
        """
        self.name = name
        self.data_root = 'data' if data_root == '' else data_root

        # A bare string is one file name; without this it would be iterated
        # character by character below.
        if isinstance(image_tsv_file, str):
            image_tsv_file = [image_tsv_file]
        if isinstance(text_tsv_file, str):
            text_tsv_file = [text_tsv_file]
        image_tsv_file = [os.path.join(self.data_root, name, tf) for tf in image_tsv_file]
        text_tsv_file = [os.path.join(self.data_root, name, tf) for tf in text_tsv_file]

        self._chunk_sizes = None
        self.num_captions = num_captions
        self.text_format = text_format
        self.load_random_mask = load_random_mask
        self.indices_list_file = indices_list_file
        self.inferior_size = inferior_size
        self.inferior_random_degree = inferior_random_degree
        self.mask_low_to_high = mask_low_to_high
        if self.inferior_size is not None:
            # Prepare the pixel cluster centres. The affine map below sends [-1, 1] to
            # [0, 255]; presumably the stored centres are normalized to [-1, 1].
            self.pixel_centers = np.load(pixel_kmens_center_path)
            self.pixel_centers = np.rint(127.5 * (1 + self.pixel_centers)) # map to origin [0-255]

        self.im_preprocessor = instantiate_from_config(im_preprocessor_config)
        self.text_preprocessor = instantiate_from_config(text_preprocessor_config)
        self.text_tokenizer = instantiate_from_config(text_tokenizer_config)

        # NOTE: this is just for reconstruction with origin size
        if return_image256:
            im_preprocessor256_config = {
                'target': 'image_synthesis.data.utils.image_preprocessor.DalleTransformerPreprocessor',
                'params': {
                    'size': 256,
                    'phase': 'val'
                }
            }
            self.im_preprocessor_256 = instantiate_from_config(im_preprocessor256_config)
        else:
            self.im_preprocessor_256 = None

        if len(image_tsv_file) == 1 and len(text_tsv_file) == 1:
            image_tsv_file = image_tsv_file[0]
            text_tsv_file = text_tsv_file[0]

        if isinstance(image_tsv_file, str) and isinstance(text_tsv_file, str):
            # single tsv file
            if (
                os.path.splitext(image_tsv_file)[1].lower() == '.tsv'
                and os.path.splitext(text_tsv_file)[1].lower() == '.tsv'
            ):
                self.image_tsv_file = TSVFile(image_tsv_file, if_generate_lineidx=True)
                self.text_tsv_file = TSVFile(text_tsv_file, if_generate_lineidx=True)
            # multiple tsv files specified in a text file
            elif (
                os.path.splitext(image_tsv_file)[1].lower() == '.txt'
                and os.path.splitext(text_tsv_file)[1].lower() == '.txt'
            ):
                self.image_tsv_file = CompositeTSVFile(image_tsv_file)
                self.text_tsv_file = CompositeTSVFile(text_tsv_file)
                self._chunk_sizes = self.image_tsv_file.get_chunk_size()
            else:
                raise ValueError("Invalid input! Please check the tsv filenames.")
        # multiple tsv files specified in a list
        elif (
            isinstance(image_tsv_file, list)
            and isinstance(text_tsv_file, list)
        ):
            assert len(image_tsv_file) == len(text_tsv_file), \
                "Inconsistent number of Image/Text tsv files!"
            self.image_tsv_file = CompositeTSVFile(image_tsv_file)
            self.text_tsv_file = CompositeTSVFile(text_tsv_file)
            self._chunk_sizes = self.image_tsv_file.get_chunk_size()
        else:
            raise ValueError("Invalid input! Please check the tsv filenames.")

        assert len(self.image_tsv_file) == len(self.text_tsv_file), \
            "Inconsistent size of Image/Text ({}/{}) data!".format(
                len(self.image_tsv_file), len(self.text_tsv_file)
            )

        self.filter_texts = filter_texts
        self._filter_data_according_text()

    @staticmethod
    def _lower_text(txt):
        """Lower-case a caption, or each caption when a list is returned (num_captions > 1)."""
        if isinstance(txt, list):
            return [t.lower() for t in txt]
        return txt.lower()

    def _text_matches_filters(self, txt):
        """Return True if ``txt`` satisfies at least one entry of ``self.filter_texts``.

        A string entry must be a substring of ``txt``. A list entry matches only
        if all of its strings are substrings of ``txt``.
        """
        for and_ft in self.filter_texts:
            if isinstance(and_ft, str):
                and_ft = [and_ft]
            if all(ft in txt for ft in and_ft):
                return True
        return False

    def _filter_data_according_text(self):
        """
        Build ``self._indices``, the row indices this dataset serves.

        The starting set is either every row or the rows listed in
        ``indices_list_file``. It is then narrowed by ``filter_texts`` and by the
        tokenizer's ``check_length``, when those are set.
        """
        if self.indices_list_file is None:
            self._indices = list(range(len(self.image_tsv_file)))
        else:
            with open(self.indices_list_file) as f:
                # Skip blank lines (e.g. a trailing newline) instead of failing on int('').
                self._indices = [int(line) for line in f if line.strip()]
        origin_length = len(self._indices)

        if self.filter_texts is not None:
            # filter according to the key texts
            assert isinstance(self.filter_texts, list)
            indices = []
            for idx in self._indices:
                _, txt = self._decode_text(self.text_tsv_file[idx])
                if self._text_matches_filters(txt):
                    indices.append(idx)
            self._indices = indices
            # NOTE(review): presumably releases the handle opened by the scan above
            # before DataLoader workers are forked; confirm against TSVFile.
            self.text_tsv_file.close()

        if self.text_tokenizer is not None:
            # filter according to the tokenizer
            indices = []
            for idx in self._indices:
                _, txt = self._decode_text(self.text_tsv_file[idx])
                txt_token = self.text_tokenizer.get_tokens([txt])[0]
                if self.text_tokenizer.check_length(txt_token):
                    indices.append(idx)
            self._indices = indices
            self.text_tsv_file.close()

        current_length = len(self._indices)
        if self.filter_texts is not None or self.text_tokenizer is not None:
            print('Filter data done! origin length: {}, current length: {}'.format(origin_length, current_length))

    def get_chunk_sizes(self):
        """Chunk sizes of a composite image source, or ``None`` for a single file."""
        return self._chunk_sizes

    def get_class_boundaries(self):
        # The samples of each class are organized class-by-class.
        # _class_boundaries stores the lower- and upper-bound of each class.
        return self.image_tsv_file.get_class_boundaries()

    def get_data_for_ui_demo(self, index):
        """Return the raw uint8 image, the index and the lower-cased caption (no preprocessing)."""
        img, txt, key = self._load_one_data(index)

        sample = {
            'image': np.array(img).astype(np.uint8),
            'index': index,
            'text': self._lower_text(txt)
        }

        return sample

    def get_inferior(self, image):
        """
        Return a degraded low-resolution copy of ``image`` with the same height and width.

        The image is resized to ``self.inferior_size`` (h, w) with bilinear
        interpolation. Each pixel is replaced by one of its ``inferior_random_degree``
        nearest colours in ``self.pixel_centers``, chosen at random with
        ``np.random``. The result is resized back to the original size with bilinear
        interpolation and returned as a uint8 array.

        ``image`` must be an array whose first two axes are (h, w) and which
        ``Image.fromarray(...).convert('RGB')`` accepts (presumably H x W x 3 in
        [0, 255]).
        """
        def squared_euclidean_distance_np(a, b):
            # d[i, j] = ||a[i] - b[j]||^2 for the rows of a and b.
            b = b.T
            a2 = np.sum(np.square(a), axis=1)
            b2 = np.sum(np.square(b), axis=0)
            ab = np.matmul(a, b)
            d = a2[:, None] - 2 * ab + b2[None, :]
            return d

        def color_quantize_np_topK(x, clusters, K):
            # For every RGB pixel, take its K nearest cluster indices (unordered)
            # and pick one of them uniformly at random.
            x = x.reshape(-1, 3)
            d = squared_euclidean_distance_np(x, clusters)
            top_K = np.argpartition(d, K, axis=1)[:, :K]

            h, w = top_K.shape
            select_index = np.random.randint(w, size=(h))
            return top_K[range(h), select_index]

        def prior_degradation(img, clusters, prior_size, K=1): ## Downsample and random change
            # PIL's resize takes (width, height), hence the swapped prior_size order.
            LR_img_cv2 = img.resize((prior_size[1], prior_size[0]), resample=Image.BILINEAR)
            LR_img_cv2 = np.array(LR_img_cv2)

            token_id = color_quantize_np_topK(LR_img_cv2.astype(clusters.dtype), clusters, K)
            primers = token_id.reshape(-1, prior_size[0] * prior_size[1])
            primers_img = [np.reshape(clusters[s], [prior_size[0], prior_size[1], 3]).astype(np.uint8) for s in primers]

            degraded = Image.fromarray(primers_img[0])

            return degraded ## degraded by inferior cluster

        h, w = image.shape[0:2]

        inferior = Image.fromarray(image.astype(np.uint8)).convert("RGB")
        inferior = prior_degradation(inferior, self.pixel_centers, self.inferior_size, K=self.inferior_random_degree)
        inferior = inferior.resize((w, h), resample=Image.BILINEAR)
        inferior = np.array(inferior).astype(np.uint8)
        return inferior

    def __getitem__(self, index: Union[int, Tuple[int, int]]):
        """
        Return one training sample as a dict.

        Keys:
            'image': preprocessed image as float32, last axis moved to the front.
            'text': lower-cased caption (after the optional text preprocessor).
            'image256': present when ``return_image256`` was set.
            'mask': bool mask of shape (1, H, W); present when ``load_random_mask``
                was set.
            'inferior': float32 degraded image; present when ``inferior_size`` was set.

        A tuple index is ``(row, file_index)``. A non-negative ``file_index`` also
        starts a daemon thread running ``pre_fetch`` on that image TSV file.
        """
        if isinstance(index, tuple):
            #NOTE: This case is not tested!
            img, txt, key = self._load_one_data(index[0])
            if index[1] >= 0:
                tsv_filename = self.image_tsv_file.file_list[index[1]]
                x = threading.Thread(
                   target=pre_fetch, args=(tsv_filename,), daemon=True
                )
                x.start()
        else:
            img, txt, key = self._load_one_data(index)

        img = np.array(img).astype(np.uint8)
        if self.im_preprocessor is not None:
            img_p = self.im_preprocessor(image=img)['image']
        else:
            img_p = img.copy()

        txt = self._lower_text(txt)
        if self.text_preprocessor is not None:
            txt = self.text_preprocessor(txt)
        sample = {
            'image': np.transpose(img_p.astype(np.float32), (2, 0, 1)),
            'text': txt,
        }

        #NOTE: add image 256 size resolution for reconstruction
        if self.im_preprocessor_256 is not None:
            img_p256 = self.im_preprocessor_256(image=img)['image']
            sample['image256'] = np.transpose(img_p256.astype(np.float32), (2, 0, 1))

        if self.load_random_mask:
            # The stroke mask is drawn on a fixed 256x256 canvas and then resized
            # (nearest neighbour) to the preprocessed image size.
            im_size = [256, 256]
            mask = generate_stroke_mask(im_size=im_size,
                                        max_parts=15,
                                        maxVertex=30,
                                        maxLength=100,
                                        maxBrushWidth=24)
            h, w = img_p.shape[0], img_p.shape[1]
            mask = Image.fromarray(mask[:, :, 0].astype(np.uint8)).resize((w, h), resample=Image.NEAREST)
            mask = np.array(mask)[:, :, np.newaxis] # H x W x 1
            # Builtin bool: the np.bool alias was removed in NumPy 1.24.
            sample['mask'] = np.transpose(mask.astype(bool), (2, 0, 1))

        if self.inferior_size is not None:
            inferior = self.get_inferior(img_p)
            sample['inferior'] = np.transpose(inferior.astype(np.float32), (2, 0, 1))

            # The random number is drawn first so the RNG stream is unchanged. Re-masking
            # needs the mask built above, so it is skipped without load_random_mask.
            if random.random() < self.mask_low_to_high and self.load_random_mask:
                h, w = self.inferior_size[0], self.inferior_size[1]
                mask = Image.fromarray(mask[:, :, 0]).resize((w, h), resample=Image.NEAREST) # H , W
                h, w = img_p.shape[0:2]
                mask = mask.resize((w, h), resample=Image.NEAREST)
                mask = np.array(mask)[:, :, np.newaxis] # H x W x 1
                sample['mask'] = np.transpose(mask.astype(bool), (2, 0, 1))

        return sample

    def _load_one_data(self, index):
        """
        Return ``(PIL image, caption, key)`` for position ``index`` of ``self._indices``.

        Images whose width/height ratio lies outside [0.5, 2] are rejected, and a
        random other row is tried instead. After 5 rejections, the last image is
        returned anyway.
        """
        index = self._indices[index]
        valid = False
        count = 0
        while (not valid) and (count < 5):
            items_image = self.image_tsv_file[index]
            items_text = self.text_tsv_file[index]

            assert items_text[0] == items_image[0], 'keys do not match for image and text, {}, {}'.format(items_text[0], items_image[0])

            _, img = self._decode_image(items_image)
            w, h = img.size
            ratio = w / float(h)
            if ratio < 0.5 or ratio > 2:
                count += 1
                index = random.randint(0, len(self._indices) - 1)
                index = self._indices[index]
            else:
                valid = True

        key, txt = self._decode_text(items_text)

        return img, txt, key

    def _decode_image(self, items: Tuple[str, str]):
        """Decode ``(key, base64 image)`` into ``(key, RGB PIL image)``."""
        key = items[0]
        image = Image.open(BytesIO(base64.b64decode(items[1]))).convert('RGB')

        return key, image

    def _decode_text(self, items: Tuple[str, Union[str, dict]]):
        """
        Decode a text row into ``(key, caption)``.

        In ``'json'`` format the caption comes from the record's ``captions`` field.
        - A string is returned as is.
        - For a list with ``num_captions == 1``, one entry is picked at random.
        - For a list otherwise, the first ``num_captions`` entries are returned as a list.
        Other formats return the raw second column.
        """
        key = items[0]
        text = ''

        if self.text_format == 'json':
            js = json.loads(items[1])
            assert 'captions' in js, '"captions" does not in {}'.format(js)
            captions = js['captions']
            if isinstance(captions, list):
                if self.num_captions == 1:
                    text = random.choice(captions)
                else:
                    text = captions
                    if len(captions) > self.num_captions:
                        text = captions[:self.num_captions]
            elif isinstance(captions, str):
                text = captions
            else:
                raise ValueError('captions should be str or list')
        else:
            text = items[1]

        return key, text

    def __len__(self):
        return len(self._indices)


class TSVTextDataset(data.Dataset):
    """
        Text-only dataset: serves the lower-cased captions stored in one or more text
        TSV files, with no paired images. Each row's first column is a key and its
        second column is the caption (raw text, or a JSON record with a ``captions``
        field when ``text_format == 'json'``).
    """
    def __init__(self,
                 name,
                 text_tsv_file,
                 data_root='',
                 num_captions=1,
                 text_format='txt'):
        """
        Args:
            name: sub-directory of ``data_root`` that holds the TSV files. Every file name
                is resolved as ``os.path.join(data_root, name, file)``.
            text_tsv_file: a file name or a list of file names.
                - A single ``.tsv`` file is opened with ``TSVFile``.
                - A single ``.txt`` file, or a list of files, is opened with
                  ``CompositeTSVFile``.
            data_root: root directory; ``''`` means ``'data'``.
            num_captions: JSON format only. 1 picks one caption at random; larger values
                keep the first ``num_captions`` captions as a list.
            text_format: ``'json'`` or anything else for raw text.

        Raises:
            ValueError: if the file names have an unsupported extension or type.
        """
        self.name = name
        self.data_root = 'data' if data_root == '' else data_root

        # A bare string is one file name; without this it would be iterated
        # character by character below.
        if isinstance(text_tsv_file, str):
            text_tsv_file = [text_tsv_file]
        text_tsv_file = [os.path.join(self.data_root, name, tf) for tf in text_tsv_file]

        self._chunk_sizes = None
        self.num_captions = num_captions
        self.text_format = text_format

        if len(text_tsv_file) == 1:
            text_tsv_file = text_tsv_file[0]

        if isinstance(text_tsv_file, str):
            # single tsv file
            if os.path.splitext(text_tsv_file)[1].lower() == '.tsv':
                self.text_tsv_file = TSVFile(text_tsv_file, if_generate_lineidx=True)
            # multiple tsv files specified in a text file
            elif os.path.splitext(text_tsv_file)[1].lower() == '.txt':
                self.text_tsv_file = CompositeTSVFile(text_tsv_file)
                self._chunk_sizes = self.text_tsv_file.get_chunk_size()
            else:
                raise ValueError("Invalid input! Please check the tsv filenames.")
        # multiple tsv files specified in a list
        elif isinstance(text_tsv_file, list):
            self.text_tsv_file = CompositeTSVFile(text_tsv_file)
            self._chunk_sizes = self.text_tsv_file.get_chunk_size()
        else:
            raise ValueError("Invalid input! Please check the tsv filenames.")

    @staticmethod
    def _lower_text(txt):
        """Lower-case a caption, or each caption when a list is returned (num_captions > 1)."""
        if isinstance(txt, list):
            return [t.lower() for t in txt]
        return txt.lower()

    def get_chunk_sizes(self):
        """Chunk sizes of a composite text source, or ``None`` for a single file."""
        return self._chunk_sizes

    def get_class_boundaries(self):
        # The samples of each class are organized class-by-class.
        # _class_boundaries stores the lower- and upper-bound of each class.
        return self.text_tsv_file.get_class_boundaries()

    def __getitem__(self, index):
        """
        Return the lower-cased caption (or list of captions) of row ``index``.

        A tuple index is ``(row, file_index)``. A non-negative ``file_index`` also
        starts a daemon thread running ``pre_fetch`` on that text TSV file.
        """
        if isinstance(index, tuple):
            txt = self._load_one_data(index[0])
            if index[1] >= 0:
                # This dataset only has a text source (there is no image_tsv_file).
                tsv_filename = self.text_tsv_file.file_list[index[1]]
                x = threading.Thread(
                   target=pre_fetch, args=(tsv_filename,), daemon=True
                )
                x.start()
        else:
            txt = self._load_one_data(index)

        return self._lower_text(txt)

    def _load_one_data(self, index):
        """Return the decoded caption of row ``index``."""
        _, txt = self._decode_text(self.text_tsv_file[index])
        return txt

    def _decode_text(self, items):
        """
        Decode a text row into ``(key, caption)``.

        In ``'json'`` format the caption comes from the record's ``captions`` field.
        - A string is returned as is.
        - For a list with ``num_captions == 1``, one entry is picked at random.
        - For a list otherwise, the first ``num_captions`` entries are returned as a list.
        Other formats return the raw second column.
        """
        key = items[0]
        text = ''

        if self.text_format == 'json':
            js = json.loads(items[1])
            assert 'captions' in js, '"captions" does not in {}'.format(js)
            captions = js['captions']
            if isinstance(captions, list):
                if self.num_captions == 1:
                    text = random.choice(captions)
                else:
                    text = captions
                    if len(captions) > self.num_captions:
                        text = captions[:self.num_captions]
            elif isinstance(captions, str):
                text = captions
            else:
                raise ValueError('captions should be str or list')
        else:
            text = items[1]

        return key, text

    def __len__(self):
        return len(self.text_tsv_file)